from tqdm import tqdm
import argparse
import torch
import time
import sys
import os

from utils import load_model, load_resize_image, negative_prompt_inversion, generate
from null_text_inversion import NullInversion

device = torch.device("cuda")


def load_captions(path):
    """Parse a caption file into ``(image_name, caption)`` pairs.

    Each non-blank line must look like ``<image_name>.jpg,<caption>``.
    Only the first ``.jpg,`` separates the two fields, so a caption may itself
    contain that substring. Blank lines are ignored, so the file may or may
    not end with a newline; Windows line endings are also accepted.

    Args:
        path: Path to the caption file (read as UTF-8).

    Returns:
        A list of ``(image_name, caption)`` tuples in file order.

    Raises:
        ValueError: If a non-blank line does not contain ``.jpg,``.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f.read().splitlines(), start=1):
            if not line.strip():
                continue
            name, sep, caption = line.partition(".jpg,")
            if not sep:
                raise ValueError(f"{path}:{lineno}: expected '<name>.jpg,<caption>', got {line!r}")
            pairs.append((name, caption))
    return pairs


def _cuda_sync():
    """Wait for queued CUDA work so wall-clock timings include it."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def main(args):
    """Reconstruct every captioned validation image and record inversion times.

    For each ``(name, caption)`` pair in ``data/captions.txt`` the image
    ``data/val2017/{name}.jpg`` is inverted into a latent ``x_T`` plus an
    unconditional embedding, regenerated from them with ``generate`` and
    saved as ``<name>.png``.

    Args:
        args: Parsed command-line namespace with ``step`` (int), ``cfg``
            (float), ``npi`` and ``nti`` (bool) attributes.

    Output directory (created if missing):
        * ``results/rec{step:03}`` with ``--npi``,
        * ``results/rec_nlt{step:03}`` with ``--nti``,
        * ``results/rec_org{step:03}`` otherwise.
        It also receives ``times.txt`` holding one inversion time in seconds
        per image, newline-separated.
    """
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    model = load_model()

    # Captions and image filenames used in the experiments
    data = load_captions("data/captions.txt")

    # The output directory depends on which inversion method is used
    if args.npi:
        out_dir = f"results/rec{args.step:03}"
    elif args.nti:
        out_dir = f"results/rec_nlt{args.step:03}"
    else:
        out_dir = f"results/rec_org{args.step:03}"
    os.makedirs(out_dir, exist_ok=True)

    if args.nti:
        null_inversion = NullInversion(model, args.step, args.cfg)

    # With neither null-text nor negative-prompt inversion, reconstruction uses
    # the plain null-text ("") embedding. It does not depend on the image, so
    # encode it once instead of on every iteration.
    null_embed = None
    if not (args.nti or args.npi):
        with torch.no_grad():
            null_embed = model._encode_prompt(
                "",
                device,
                num_images_per_prompt=1,
                do_classifier_free_guidance=False,
                negative_prompt=None,
            ).detach()

    times = []
    for name, caption in tqdm(data):
        input_image = load_resize_image(f"data/val2017/{name}.jpg")

        # Time only the inversion. CUDA kernels run asynchronously, so
        # synchronize before starting and before stopping the clock.
        _cuda_sync()
        t0 = time.perf_counter()
        if args.nti:
            x_T, uncond_embed = null_inversion.invert(input_image, caption)
        else:
            x_T, uncond_embed = negative_prompt_inversion(
                model, input_image, caption, args.step
            )
        _cuda_sync()
        times.append(time.perf_counter() - t0)

        if null_embed is not None:
            uncond_embed = null_embed

        recons_image = generate(model, caption, uncond_embed, x_T, args.step, args.cfg)
        recons_image.save(os.path.join(out_dir, name + ".png"))

    # One inversion time (seconds) per line
    with open(os.path.join(out_dir, "times.txt"), "w", encoding="utf-8") as f:
        f.write("\n".join(map(str, times)))


def build_parser():
    """Build the command-line parser for the reconstruction experiment.

    ``--npi`` and ``--nti`` select the inversion method and are mutually
    exclusive. Passing both makes ``parse_args`` print a usage error and exit
    with status 2. Passing neither selects the plain null-text baseline.
    """
    # The text is a description; the first positional argument would be ``prog``
    parser = argparse.ArgumentParser(description="image reconstruction")

    parser.add_argument("--step", type=int, default=50, help="Number of steps to generate")
    parser.add_argument("--cfg", type=float, default=7.5, help="Classifier-free Guidance scale")

    # Let argparse enforce "at most one method" and report misuse cleanly
    method = parser.add_mutually_exclusive_group()
    method.add_argument("--npi", action="store_true", help="Use negative-prompt inversion")
    method.add_argument("--nti", action="store_true", help="Use null-text inversion")

    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())
